{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用生成对抗网络实现图像转换\n",
    "\n",
    "**此案例使用GPU算力，请参照注意事项完成规格切换**\n",
    "\n",
    "## 注意事项：\n",
    "\n",
    "1. 本案例使用AI引擎**:** **TensorFlow-1.13.1**\n",
    "\n",
    "2. 本案例最低硬件规格要求**:**  **类型选择GPU，目标规格选择8U + 64GiB + 1GPU**\n",
    "\n",
    "3. 切换硬件规格方法**:**  如需切换硬件规格,您可以在本页面右边的工作区进行切换\n",
    "\n",
    "4. 运行代码方法**:**  点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码\n",
    "\n",
    "5. JupyterLab的详细用法**:**  [请参考《ModelAtrs JupyterLab使用指导》](https://bbs.huaweicloud.com/forum/thread-97603-1-1.html)\n",
    "\n",
    "6. Kernel Restaring,Kernel died及其他常见问题的解决办法**:** [请参考《ModelAtrs JupyterLab常见问题解决办法》](https://bbs.huaweicloud.com/forum/thread-98681-1-1.html)\n",
    "\n",
    "## 案例内容介绍\n",
    "\n",
    "生成对抗网络（Generative Adversarial Nets，以下简称GAN）是Ian J. Goodfellow等人在2014年的论文[Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661.pdf)中提出的一种评估生成模型的方法。GAN中巧妙的提出了以对抗的方法来训练生成和判别两个模型，在两者的对抗过程中不断改进，最终达到一种平衡。神经网络奠基人之一的Yann LeCun称赞对抗训练是“在过去10年中最有趣的机器学习想法”。\n",
    "\n",
    "### 生成模型（generator）与判别模型（discrimator）\n",
    "GAN网络（生成对抗网络），可以认为是一个造假机器，造出来的东西跟真的一样。打比方说生成模型就相当于造假钞的，判别模型相当于银行，生成模型造好假钞之后拿到银行去存，银行（判别模型）说：“ 不好意思，先生，你这是假钞，你看你的钞票上连毛爷爷都没有”。生成模型不服气，回来继续造，他就把毛爷爷印上去了，又去了银行,银行又说：“对不起，先生，你这是假钞，我们这100元的没有黄色的，都是红色的”。生成模型回来后，又把假钞印成红色，互相对抗，互相博弈，反复数次之后，生成模型已经把假钞造得跟真钞一样了，至少银行（判别模型）已经认不出来了。这时候，双方就达到了一种动态的平衡的状态。\n",
    "\n",
    "- 生成模型的主要任务是：生成尽可能逼近真样本的假样本\n",
    "\n",
    "- 判别模型的主要任务是：尽可能高准确度的判别假样本\n",
    "\n",
    "#### 相同点\n",
    "这两个模型都可以看成是一个黑匣子，接受输入然后有一个输出，类似于一个函数，一个输入输出的映射。\n",
    "\n",
    "#### 不同点\n",
    "生成模型功能：比作是一个样本生成器，输入一个噪声/样本，然后把它包装成一个逼真的样本，也就是输出。\n",
    "\n",
    "判别模型：比作一个二分类器（如同0-1分类器），来判断输入的样本是真是假。（就是输出值大于0.5还是小于0.5）。\n",
    "\n",
    "### 单独交替迭代训练\n",
    "1.训练判别模型：\n",
    "- 假样本（0）：给一堆随机数组，就会得到一堆假的样本集；真样本（1）：真实数据\n",
    "- 就是二分类模型的训练，直接将数据喂入判别模型即可\n",
    "\n",
    "2.训练生成模型：\n",
    "- 生成模型生成假样本，但是只用生成模型自己的话没有误差反馈，也就是说没有误差来源，无法训练。所以将上一步训练好的判别模型连接在生成模型后一起训练。\n",
    "- 如何连接？因为两个模型的输入都输同样大小的数据，比如说人脸照片，生成的照片数据就可以作为判别模型的输入。\n",
    "- 训练数据是什么？第一步中我们产生的假样本，将其标签设置为1（就是当做真样本），注意：生成模型的数据只有这个假样本。这样才能迷惑判别器，逐渐使生成的假样本逼近真样本\n",
    "- 训练时另外重要的一点是，我们要冻结判别模型的参数。判别模型的作用只是传递误差，而不是联合训练。（好不容易训练好的判别摸性，再重新训练不就是白训练了吗，所有不能修改判别模型的参数）\n",
    "\n",
    "3.生成模型训练好之后，我们的假样本更加逼真，这时再将这些逼真的假样本当做假样本，真实数据为真样本进行步骤1，2的训练。判别模型的判别能力越来越高，生成模型生成的假样本也越来越逼真，这就叫做单独交替迭代训练。\n",
    "\n",
    "现在市面上有许多好玩的案例都是基于GAN来实现的，比如给黑白照片上色，3D对象生成等等。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAN的扩展\n",
    "\n",
    "GAN论文发表后，在业界引起了广泛的研究和探讨，产生了很多基于GAN的扩展模型，例如使用深度卷积网络定义模型的[DCGAN](https://arxiv.org/abs/1511.06434)，基于GAN的图片转换方法[pix2pix](https://arxiv.org/abs/1611.07004)等，并产生了很多有趣的应用。接下来我们以[CycleGAN](https://arxiv.org/abs/1703.10593)为例，介绍GAN的相关应用。\n",
    "\n",
    "#### CycleGAN\n",
    "\n",
    "CycleGAN是在2017年的论文《Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks\n",
    "》中提出的，从论文标题可以看出，CycleGAN提出了一种不需要配对的图像转换方法。在之前的相关研究中，一般在训练时需要提供一套匹配的图像数据才能实现图像的转换，而CycleGAN则是直接在两种不同风格的图像数据集中学习特征，然后尝试用GAN的方法训练模型，最终让模型可以把一种数据集中的风格特征套用到另一种数据集中。\n",
    "\n",
    "\n",
    "在CycleGAN中，我们假设有两套不同风格的图像数据$X$和$Y$（如一组是实拍风景，另一组是油画，两者内容不必匹配），按照GAN的思路，我们可以构建一个生成模型G，对$X\\to{Y}$进行建模，然后构建一个判别模型D，判断图片是否是G生成的伪造图像。然而，CycleGAN的作者在实践中发现，$X \\to{Y}$的映射是有多种解的，很难保证两者的映射关系在视觉上是有意义的，并且常规的GAN手段也容易引起转换效果的下降（[mode collapse](https://aiden.nibali.org/blog/2017-01-18-mode-collapse-gans/)）：$X$的数据分布被映射到$Y$中的少数几个实例上。\n",
    "\n",
    "CycleGAN的作者提出，我们不仅保证映射关系$G: X \\to{Y}$，还可以引入另一个映射关系$F: Y\\to{X}$，在训练时，我们同时训练两种映射$G和F$，并引入一个新的\"循环一致性损失\"(cycle consistent loss)，即：对映射好的关系，再用反向映射，使两次映射的结果向原输入拟合：\n",
    "\n",
    "$$F(G(x)) \\approx {x},     G(F(y)) \\approx {y}$$\n",
    "\n",
    "这样，综合生成对抗损失和循环一致性损失，我们就可以实现无需配对的图像转换。这种方法具备更好的泛化性，可以用于图片风格转换，物体变换，图片增强等场景中。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 实现CycleGAN\n",
    "\n",
    "下面，我们实现一个简单的CycleGAN结构，并用于物体转换（苹果和橘子的转换）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实验目标\n",
    "\n",
    "1.掌握CycleGAN的概念和原理。\n",
    "\n",
    "2.掌握如何使用华为云ModelArts的Jupyter Notebook下载上传数据，执行python代码。\n",
    "\n",
    "3.掌握TensorFlow框架和keras的一些库的使用。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实验步骤"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先，准备数据和工具代码，**由于我们安装了新的库keras_contrib，执行完下面cell后需要重启一下jupyter kernel。**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import moxing as mox\n",
    "if not os.path.exists('./gan.tar.gz'):\n",
    "    mox.file.copy_parallel('obs://modelarts-labs-bj4/course/ai_in_action/2021/GAN/GAN/gan.tar.gz', './gan.tar.gz')\n",
    "if os.path.exists('./gan.tar.gz'):\n",
    "    ! tar zxf gan.tar.gz\n",
    "    ! cd data/keras-contrib-master/ && python setup.py install"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 注意：重启kernel操作为点击本页面左上角的“kernel”按钮，找到\"Restart Kernel\"命令，点击它，并在弹出框里点击“Restart\"即可,重启后，无需重复执行上面代码，从下面开始执行即可。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，定义generator和discriminator的构造方法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.layers import Conv2D, LeakyReLU, UpSampling2D, Dropout, Concatenate, Input\n",
    "from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization\n",
    "from keras.models import Model\n",
    "\n",
    "img_rows = 128\n",
    "img_cols = 128\n",
    "channels = 3\n",
    "img_shape = (img_rows, img_cols, channels)\n",
    "\n",
    "patch = int(img_rows / 2**4)\n",
    "disc_patch = (patch, patch, 1)\n",
    "\n",
    "# Number of filters in the first layer of G and D\n",
    "gf = 32\n",
    "df = 64\n",
    "\n",
    "# Loss weights\n",
    "lambda_cycle = 10.0               # Cycle-consistency loss\n",
    "lambda_id = 0.1 * lambda_cycle    # Identity loss\n",
    "\n",
    "\n",
    "def build_generator():\n",
    "    def conv2d(layer_input, filters, f_size=4):\n",
    "        \"\"\"Layers used during downsampling\"\"\"\n",
    "        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)\n",
    "        d = LeakyReLU(alpha=0.2)(d)\n",
    "        d = InstanceNormalization()(d)\n",
    "        return d\n",
    "\n",
    "    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):\n",
    "        \"\"\"Layers used during upsampling\"\"\"\n",
    "        u = UpSampling2D(size=2)(layer_input)\n",
    "        u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)\n",
    "        if dropout_rate:\n",
    "            u = Dropout(dropout_rate)(u)\n",
    "        u = InstanceNormalization()(u)\n",
    "        u = Concatenate()([u, skip_input])\n",
    "        return u\n",
    "\n",
    "    # Image input\n",
    "    d0 = Input(shape=img_shape)\n",
    "\n",
    "    # Downsampling\n",
    "    d1 = conv2d(d0, gf)\n",
    "    d2 = conv2d(d1, gf*2)\n",
    "    d3 = conv2d(d2, gf*4)\n",
    "    d4 = conv2d(d3, gf*8)\n",
    "\n",
    "    # Upsampling\n",
    "    u1 = deconv2d(d4, d3, gf*4)\n",
    "    u2 = deconv2d(u1, d2, gf*2)\n",
    "    u3 = deconv2d(u2, d1, gf)\n",
    "\n",
    "    u4 = UpSampling2D(size=2)(u3)\n",
    "    output_img = Conv2D(channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)\n",
    "\n",
    "    return Model(d0, output_img)\n",
    "\n",
    "def build_discriminator():\n",
    "    def d_layer(layer_input, filters, f_size=4, normalization=True):\n",
    "        \"\"\"Discriminator layer\"\"\"\n",
    "        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)\n",
    "        d = LeakyReLU(alpha=0.2)(d)\n",
    "        if normalization:\n",
    "            d = InstanceNormalization()(d)\n",
    "        return d\n",
    "\n",
    "    img = Input(shape=img_shape)\n",
    "\n",
    "    d1 = d_layer(img, df, normalization=False)\n",
    "    d2 = d_layer(d1, df*2)\n",
    "    d3 = d_layer(d2, df*4)\n",
    "    d4 = d_layer(d3, df*8)\n",
    "\n",
    "    validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)\n",
    "\n",
    "    return Model(img, validity)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后，构建生成模型$G，F$和对应的判别模型$D_Y, D_X$："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.optimizers import Adam\n",
    "\n",
    "optimizer = Adam(0.0002, 0.5)\n",
    "\n",
    "# 构建判别模型Dy和Dx\n",
    "Dx = build_discriminator()\n",
    "Dy = build_discriminator()\n",
    "\n",
    "Dx.compile(loss='mse',\n",
    "    optimizer=optimizer,\n",
    "    metrics=['accuracy'])\n",
    "Dy.compile(loss='mse',\n",
    "    optimizer=optimizer,\n",
    "    metrics=['accuracy'])\n",
    "\n",
    "# 构建生成模型G和F\n",
    "G = build_generator()\n",
    "F = build_generator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们定义整个CycleGAN的计算图关系。在训练GAN时，我们用均方误差(Mean Square Error)损失函数来替代论文中的$log$似然损失，根据Xudong Mao等人的[论文](https://arxiv.org/pdf/1611.04076.pdf)，使用L2 loss（也称Least Square Errors）比典型的$log$似然更加高效。L2损失的定义为真实值和预测值差的平方和：\n",
    "\n",
    "$$ LS = \\sum_{i=1}^{n} (G(x) - x)^{2} $$\n",
    "\n",
    "而MSE为对平方和再求平均数：\n",
    "\n",
    "$$ MSE =  \\frac{1}{N}  \\sum_{i=1}^{n} (G(x) - x)^{2} $$\n",
    "\n",
    "在我们使用的框架Keras中，并没有直接提供L2损失函数，而是提供了MSE，常数在求解时候可以忽略，因此不影响GAN的训练。\n",
    "\n",
    "同样，对cycle consistency loss和identity loss，Keras也没有直接提供L1损失函数，我们也使用MAE(Mean Average Error)替代。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据空间X和Y的输入\n",
    "img_X = Input(shape=img_shape)\n",
    "img_Y = Input(shape=img_shape)\n",
    "\n",
    "# 生成伪造的Y和X\n",
    "fake_Y = G(img_X)\n",
    "fake_X = F(img_Y)\n",
    "# 将伪造的数据再映射回原来的数据空间\n",
    "reconstr_X = F(fake_Y)\n",
    "reconstr_Y = G(fake_X)\n",
    "\n",
    "# 构建恒等映射，让生成模型在生成数据时不会过度改动原图特征\n",
    "img_X_id = F(img_X)\n",
    "img_Y_id = G(img_Y)\n",
    "\n",
    "# 冻结判别模型，在combined模型中，只训练生成模型\n",
    "Dx.trainable = False\n",
    "Dy.trainable = False\n",
    "\n",
    "# 判别模型对伪造数据进行判定，的判别结果 Discriminators determines validity of translated images\n",
    "valid_X = Dx(fake_X)\n",
    "valid_Y = Dy(fake_Y)\n",
    "\n",
    "# 生成模型在combined model中被训练。\n",
    "combined = Model(inputs=[img_X, img_Y],\n",
    "                 outputs=[valid_X, valid_Y,\n",
    "                          reconstr_X, reconstr_Y,\n",
    "                          img_X_id, img_Y_id])\n",
    "\n",
    "# 我们使用MSE损失函数代替GAN论文中的log似然\n",
    "# 用MAE损失函数代替循环一致性的L1损失函数\n",
    "combined.compile(loss=['mse', 'mse',\n",
    "                       'mae', 'mae',\n",
    "                       'mae', 'mae'],\n",
    "                 loss_weights=[1, 1,\n",
    "                               lambda_cycle, lambda_cycle,\n",
    "                               lambda_id, lambda_id ],\n",
    "                 optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型\n",
    "\n",
    "训练之前，我们先定义一个助手函数，用来随机生成数据，观察模型效果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.cyclegan.data_loader import DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# plt.figure(figsize=(20, 10));\n",
    "\n",
    "def sample_train_results():\n",
    "    dataset_name = 'apple2orange'\n",
    "    dataloader = DataLoader(dataset_name=dataset_name,\n",
    "                                  img_res=(img_rows, img_cols))\n",
    "\n",
    "    r, c = 2, 3\n",
    "\n",
    "    imgs_X = dataloader.load_data(domain=\"A\", batch_size=1, is_testing=True)\n",
    "    imgs_Y = dataloader.load_data(domain=\"B\", batch_size=1, is_testing=True)\n",
    "\n",
    "\n",
    "    # Translate images to the other domain\n",
    "    # 图片转换\n",
    "    fake_Y = G.predict(imgs_X)\n",
    "    fake_X = F.predict(imgs_Y)\n",
    "\n",
    "    # 重新将图片转换回原来的数据空间\n",
    "    reconstr_X = F.predict(fake_Y)\n",
    "    reconstr_Y = G.predict(fake_X)\n",
    "\n",
    "    gen_imgs = np.concatenate([imgs_X, fake_Y, reconstr_X, imgs_Y, fake_X, reconstr_Y])\n",
    "\n",
    "    # Rescale images 0 - 1\n",
    "    gen_imgs = 0.5 * gen_imgs + 0.5\n",
    "\n",
    "    titles = ['Original', 'Translated', 'Reconstructed']\n",
    "\n",
    "    fig, axs = plt.subplots(r, c)\n",
    "\n",
    "    cnt = 0\n",
    "    for i in range(r):\n",
    "        for j in range(c):\n",
    "            axs[i,j].imshow(gen_imgs[cnt]);\n",
    "            axs[i, j].set_title(titles[j])\n",
    "            axs[i,j].axis('off')\n",
    "            cnt += 1\n",
    "    \n",
    "    plt.show();   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面，我们训练模型，并定期显示模型训练的效果，训练模型大概花费6分钟。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "epochs = 3\n",
    "batch_size = 1\n",
    "sample_results_interval = 200 # 每训练100个batch观察一次生成模型的结果\n",
    "\n",
    "valid = np.ones((batch_size,) + disc_patch)\n",
    "fake = np.zeros((batch_size,) + disc_patch)\n",
    "\n",
    "dataset_name = 'apple2orange'\n",
    "dataloader = DataLoader(dataset_name=dataset_name,\n",
    "                              img_res=(img_rows, img_cols))\n",
    "\n",
    "import datetime\n",
    "\n",
    "start_time = datetime.datetime.now()\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for batch_i, (imgs_X, imgs_Y) in enumerate(dataloader.load_batch(batch_size)):\n",
    "        # 训练判别模型d_X, d_Y\n",
    "\n",
    "        \n",
    "        # 首先，生成伪造的Y'=G(x)和伪造的X'=F(y)\n",
    "        fake_Y = G.predict(imgs_X)\n",
    "        fake_X = F.predict(imgs_Y)\n",
    "\n",
    "        # 用真实数据和伪造数据训练判别模型\n",
    "        Dx_loss_real = Dx.train_on_batch(imgs_X, valid)\n",
    "        Dx_loss_fake = Dy.train_on_batch(fake_X, fake)\n",
    "        Dx_loss = 0.5 * np.add(Dx_loss_real, Dx_loss_fake)\n",
    "\n",
    "        Dy_loss_real = Dy.train_on_batch(imgs_Y, valid)\n",
    "        Dy_loss_fake = Dy.train_on_batch(fake_Y, fake)\n",
    "        Dy_loss = 0.5 * np.add(Dy_loss_real, Dy_loss_fake)\n",
    "\n",
    "        # 对Dx和Dx损失求平均，作为生成模型的整体损失\n",
    "        d_loss = 0.5 * np.add(Dx_loss, Dy_loss)\n",
    "\n",
    "        # 在combined model中训练生成模型\n",
    "        g_loss = combined.train_on_batch([imgs_X, imgs_Y],\n",
    "                                         [valid, valid,\n",
    "                                          imgs_X, imgs_Y,\n",
    "                                          imgs_X, imgs_Y])\n",
    "\n",
    "        elapsed_time = datetime.datetime.now() - start_time\n",
    "        \n",
    "        if batch_i % sample_results_interval == 0:\n",
    "            print (\"[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s \" \\\n",
    "                                                                    % ( epoch+1, epochs,\n",
    "                                                                        batch_i, dataloader.n_batches,\n",
    "                                                                        d_loss[0], 100*d_loss[1],\n",
    "                                                                        g_loss[0],\n",
    "                                                                        np.mean(g_loss[1:3]),\n",
    "                                                                        np.mean(g_loss[3:5]),\n",
    "                                                                        np.mean(g_loss[5:6]),\n",
    "                                                                        elapsed_time))\n",
    "            sample_train_results()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CycleGAN训练总结\n",
    "\n",
    "\n",
    "可以看到，刚开始训练时，生成的图片包含较多的噪点，很容易判别。而经过若干论训练，整个CycleGAN已经学习到苹果和橘子的主要特征，包括颜色、纹理，其中视觉上最明显的特征是颜色。在我们的样本中，苹果多为红色，而橘子多为橙色，在转换时，颜色特征的转换非常明显，在一些训练情况较好权重下，从橘子转换为苹果时，也可以看到生成的水果纹理比原图更光滑。\n",
    "\n",
    "\n",
    "此外，在训练过程中，我们还发现了一些有趣的现象。在一次训练中，橘子样本图片的内容为案板上摆放的橘子和紫色的茄子，模型在训练时将紫色的茄子当成橘子的特征并训练得到了较高的权重，因此后面若干轮的训练中，都把紫色作为一个主要特征，应用到图片的生成和双向重建中，并一直沿着这个优化的方向延续到训练结束。对于这种情况，我们应该适当的提前终止训练，并剔除掉数据集中特征分散、不明确的样本。\n",
    "\n",
    "\n",
    "![cyclegan-bad-training](https://modelarts-labs-bj4.obs.cn-north-4.myhuaweicloud.com/course/ai_in_action/2021/GAN/GAN/img/cyclegan-bad-training.png)\n",
    "> 训练过程中，遇到特征分散的样本并且模型恰好按照错误的特征进行训练\n",
    "\n",
    "\n",
    "\n",
    "由于CycleGAN中的模型较多，模型间关系也比较繁杂，模型之间梯度下降的方向不易协调，因此很难用一个指标去衡量整个CycleGAN的模型性能。CycleGAN的作者在公布的Github项目说明中也阐述了这一点，并针对CycleGAN不易训练的特点给出了[若干提示](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md)，建议开发者周期性的用训练好的参数去生成数据并观察结果。此外，行业中也有一些相关的优化技巧的讨论，如在生成伪造数据时加入高斯分布以避免判别模型捕捉到生成图像颜色值的数字特征，以及[label smoothing](https://arxiv.org/abs/1906.02629)技术等。\n",
    "\n",
    "\n",
    "- **扩展1**：参考CycleGAN作者给出的训练建议，以及业界相关的优化技巧，优化CycleGAN的训练\n",
    "- **扩展2**：阅读src/cyclegan/目录下的相关代码和脚本，尝试下载其他数据集训练CycleGAN并观察模型效果"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "至此，本案例完成。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TensorFlow-1.13.1",
   "language": "python",
   "name": "tensorflow-1.13.1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
